import openpyxl

class DataImporter:
    """Import rows from an xlsx worksheet as test-data dicts.

    Row 1 of the sheet is treated as the header row. Every following row is
    turned into a dict whose keys are element names, obtained by mapping
    header text to element names through ``col_mapping``.

    Args:
        xls_path: Workbook location, passed straight to
            ``openpyxl.load_workbook``.
        col_mapping: Mapping of header text -> element name. Header texts
            that do not appear in row 1 are silently ignored.
        sheet_name: Name of the worksheet to read; when falsy (``None`` or
            ``""``) the workbook's active sheet is used.
    """

    def __init__(self, xls_path, col_mapping, sheet_name=None):
        self.xls_path = xls_path
        self.col_mapping = col_mapping
        self.sheet_name = sheet_name

    def import_data(self):
        """Read the sheet and return one dict per data row.

        Returns:
            list[dict]: For every row after the header, a dict mapping each
            element name whose header column exists to that row's cell value
            (``None`` for empty cells). Key order follows ``col_mapping``.

        Raises:
            KeyError: if ``sheet_name`` is given but the workbook has no
                worksheet with that name.
        """
        # NOTE(review): load_workbook is called without data_only=True, so
        # formula cells yield the formula string, not the cached value —
        # confirm that is intended for test data.
        workbook = openpyxl.load_workbook(self.xls_path)
        try:
            sheet = workbook[self.sheet_name] if self.sheet_name else workbook.active
            header = [cell.value for cell in next(sheet.iter_rows(min_row=1, max_row=1))]
            # Resolve header text -> column index once, instead of scanning the
            # header for every mapped column of every data row. list.index
            # keeps "first matching header column wins"; if two headers map to
            # the same element name, the later mapping entry wins, as before.
            col_index = {
                name: header.index(col_name)
                for col_name, name in self.col_mapping.items()
                if col_name in header
            }
            return [
                {name: row[idx] for name, idx in col_index.items()}
                for row in sheet.iter_rows(min_row=2, values_only=True)
            ]
        finally:
            # A no-op for regular workbooks, but releases the underlying file
            # handle if the workbook is ever opened in read-only mode.
            workbook.close()

    def iter_test_data(self):
        """Yield the dicts produced by ``import_data`` one at a time.

        The whole sheet is still read eagerly by ``import_data`` on the first
        ``next()`` call; this is a convenience iterator, not a streaming
        reader.
        """
        yield from self.import_data()